import os
import math
import time
import logging
from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from torchvision import transforms

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers.optimization import get_scheduler

from dataset.font_dataset import FontDataset
from dataset.collate_fn import CollateFN
from configs.fontdiffuser import get_parser
from src import (
    FontDiffuserModel,
    ContentPerceptualLoss,
    build_unet,
    build_style_encoder,
    build_content_encoder,
    build_ddpm_scheduler,
    build_scr,
)
from utils import (
    save_args_to_yaml,
    x0_from_epsilon,
    reNormalize_img,
    normalize_mean_std,
)

# Module-level logger obtained from accelerate.logging.
# NOTE(review): main() below emits its messages through the root `logging`
# module rather than through this object — confirm whether that is intended.
logger = get_logger(__name__)


def get_args():
    """Parse the command-line arguments and normalize them for training.

    Two adjustments are applied to the parsed namespace:

    * when the ``LOCAL_RANK`` environment variable is set and differs from
      ``args.local_rank``, the environment value wins;
    * the scalar ``style_image_size`` / ``content_image_size`` values are
      expanded into ``(size, size)`` tuples.

    Returns:
        The parsed (and adjusted) argument namespace.
    """
    args = get_parser().parse_args()

    # -1 is the sentinel for "LOCAL_RANK not present in the environment".
    local_rank_from_env = int(os.environ.get("LOCAL_RANK", -1))
    if local_rank_from_env != -1:
        if local_rank_from_env != args.local_rank:
            args.local_rank = local_rank_from_env

    # Turn each scalar edge length into a square (height, width) pair.
    args.style_image_size = (args.style_image_size,) * 2
    args.content_image_size = (args.content_image_size,) * 2

    return args


def main():
    """Run FontDiffuser training.

    The function parses the arguments, builds the UNet / style encoder /
    content encoder and wraps them in ``FontDiffuserModel``, builds the
    dataset, optimizer and LR scheduler, hands them to Accelerate, and then
    runs the noise-prediction training loop until ``args.max_train_steps``
    optimizer updates have been performed.

    When ``args.phase_2`` is truthy, the three sub-networks are initialized
    from ``args.phase_1_ckpt_dir`` and a frozen SCR module (loaded from
    ``args.scr_ckpt_path``) contributes an additional loss term weighted by
    ``args.sc_coefficient``.

    Side effects:
        * creates ``args.output_dir`` and writes
          ``fontdiffuser_training.log`` and the YAML config dump into it;
        * every ``args.ckpt_interval`` updates, the main process writes a
          ``global_step_<N>`` checkpoint directory.
    """
    args = get_args()

    logging_dir = f"{args.output_dir}/{args.logging_dir}"

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=logging_dir,
    )

    # Every process opens the log file below, so the directory has to exist
    # for all of them, not only for the main process. `exist_ok=True` makes
    # the concurrent calls harmless.
    os.makedirs(args.output_dir, exist_ok=True)

    logging.basicConfig(
        filename=f"{args.output_dir}/fontdiffuser_training.log",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Set the training seed
    if args.seed is not None:
        set_seed(args.seed)

    # Load model and noise_scheduler
    unet = build_unet(args=args)
    style_encoder = build_style_encoder(args=args)
    content_encoder = build_content_encoder(args=args)
    noise_scheduler = build_ddpm_scheduler(args)
    if args.phase_2:
        # Load onto CPU first: the modules are moved to the right device
        # later by accelerator.prepare(), and this avoids depending on the
        # device the checkpoint happened to be saved from.
        unet.load_state_dict(
            torch.load(f"{args.phase_1_ckpt_dir}/unet.pth", map_location="cpu")
        )
        style_encoder.load_state_dict(
            torch.load(
                f"{args.phase_1_ckpt_dir}/style_encoder.pth", map_location="cpu"
            )
        )
        content_encoder.load_state_dict(
            torch.load(
                f"{args.phase_1_ckpt_dir}/content_encoder.pth", map_location="cpu"
            )
        )

    model = FontDiffuserModel(
        unet=unet, style_encoder=style_encoder, content_encoder=content_encoder
    )

    # Build the content perceptual loss. It scores the perceptual difference
    # between the generated glyph image and the target glyph image, i.e. a
    # comparison in feature space rather than a purely pixel-wise one.
    # NOTE(review): the original note says the features come from a
    # pretrained network such as VGG — confirm in ContentPerceptualLoss.
    perceptual_loss = ContentPerceptualLoss()

    # Load SCR module for supervision (phase 2 only); it is kept frozen.
    if args.phase_2:
        scr = build_scr(args=args)
        scr.load_state_dict(torch.load(args.scr_ckpt_path, map_location="cpu"))
        scr.requires_grad_(False)

    # Content, style and target images share the same pipeline and differ
    # only in their output size.
    content_transforms = _build_image_transform(args.content_image_size)
    style_transforms = _build_image_transform(args.style_image_size)
    target_transforms = _build_image_transform((args.resolution, args.resolution))

    train_font_dataset = FontDataset(
        args=args,
        phase="train",
        transforms=[content_transforms, style_transforms, target_transforms],
        scr=args.phase_2,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_font_dataset,
        shuffle=True,
        batch_size=args.train_batch_size,
        collate_fn=CollateFN(),
    )

    # Build optimizer and learning rate
    if args.scale_lr:
        # Scale the base LR by the effective global batch size.
        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.train_batch_size
            * accelerator.num_processes
        )
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        # beta1 / beta2 are the decay rates of Adam's first- and
        # second-moment estimates.
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
    # The LR scheduler adjusts the learning rate over the course of training.
    # NOTE(review): both step counts are multiplied by
    # gradient_accumulation_steps, presumably because lr_scheduler.step() is
    # called on every micro-batch below — confirm against the Accelerate
    # version in use.
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Accelerate preparation
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    # Move the scr module to the target device (it is not passed through
    # accelerator.prepare()).
    if args.phase_2:
        scr = scr.to(accelerator.device)

    # Tracker initialization and the config dump are done by the main
    # process only; the other processes just compute.
    if accelerator.is_main_process:
        accelerator.init_trackers(args.experience_name)
        save_args_to_yaml(
            args=args,
            output_file=f"{args.output_dir}/{args.experience_name}_config.yaml",
        )

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )
    progress_bar.set_description("Steps")

    # Convert the step budget into a number of epochs
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    global_step = 0
    for epoch in range(num_train_epochs):
        train_loss = 0.0
        for step, samples in enumerate(train_dataloader):
            model.train()
            content_images = samples["content_image"]
            style_images = samples["style_image"]
            target_images = samples["target_image"]
            nonorm_target_images = samples["nonorm_target_image"]

            with accelerator.accumulate(model):
                # Sample noise that we'll add to the samples
                noise = torch.randn_like(target_images)
                bsz = target_images.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    noise_scheduler.num_train_timesteps,
                    (bsz,),
                    device=target_images.device,
                )
                timesteps = timesteps.long()

                # Add noise to the target_images according to the noise
                # magnitude at each timestep (the forward diffusion process).
                noisy_target_images = noise_scheduler.add_noise(
                    target_images, noise, timesteps
                )

                # Classifier-free training strategy: with probability
                # args.drop_prob a sample has both its content and its style
                # image overwritten (in place) with the constant 1.
                context_mask = torch.bernoulli(torch.zeros(bsz) + args.drop_prob)
                for i, mask_value in enumerate(context_mask):
                    if mask_value == 1:
                        content_images[i, :, :, :] = 1
                        style_images[i, :, :, :] = 1

                # Predict the noise residual and compute loss
                noise_pred, offset_out_sum = model(
                    x_t=noisy_target_images,
                    timesteps=timesteps,
                    style_images=style_images,
                    content_images=content_images,
                    content_encoder_downsample_size=args.content_encoder_downsample_size,
                )
                # MSE between the predicted noise and the noise actually added.
                diff_loss = F.mse_loss(
                    noise_pred.float(), noise.float(), reduction="mean"
                )
                offset_loss = offset_out_sum / 2

                # Output processing for the content perceptual loss.
                # NOTE(review): x0_from_epsilon presumably recovers the
                # clean-image estimate from the noise prediction — confirm
                # in utils.
                pred_original_sample_norm = x0_from_epsilon(
                    scheduler=noise_scheduler,
                    noise_pred=noise_pred,
                    x_t=noisy_target_images,
                    timesteps=timesteps,
                )
                pred_original_sample = reNormalize_img(pred_original_sample_norm)
                norm_pred_ori = normalize_mean_std(pred_original_sample)
                norm_target_ori = normalize_mean_std(nonorm_target_images)
                percep_loss = perceptual_loss.calculate_loss(
                    generated_images=norm_pred_ori,
                    target_images=norm_target_ori,
                    device=target_images.device,
                )

                loss = (
                    diff_loss
                    + args.perceptual_coefficient * percep_loss
                    + args.offset_coefficient * offset_loss
                )

                if args.phase_2:
                    neg_images = samples["neg_images"]
                    # Style contrastive (sc) loss from the frozen SCR module.
                    (
                        sample_style_embeddings,
                        pos_style_embeddings,
                        neg_style_embeddings,
                    ) = scr(
                        pred_original_sample_norm,
                        target_images,
                        neg_images,
                        nce_layers=args.nce_layers,
                    )
                    sc_loss = scr.calculate_nce_loss(
                        sample_s=sample_style_embeddings,
                        pos_s=pos_style_embeddings,
                        neg_s=neg_style_embeddings,
                    )
                    loss += args.sc_coefficient * sc_loss

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if accelerator.is_main_process:
                    if global_step % args.ckpt_interval == 0:
                        save_dir = f"{args.output_dir}/global_step_{global_step}"
                        _save_checkpoint(accelerator, model, save_dir)
                        logging.info(
                            f"[{_timestamp()}] Save the checkpoint on global step {global_step}"
                        )
                        print(
                            "Save the checkpoint on global step {}".format(global_step)
                        )

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            if global_step % args.log_interval == 0:
                logging.info(
                    f"[{_timestamp()}] Global Step {global_step} => train_loss = {loss}"
                )
            progress_bar.set_postfix(**logs)

            # Quit
            if global_step >= args.max_train_steps:
                break

    accelerator.end_training()


def _build_image_transform(size):
    """Return the resize -> tensor -> normalize pipeline for one image role.

    Args:
        size: Target ``(height, width)`` passed to ``transforms.Resize``.

    The pipeline resizes with bilinear interpolation, converts the PIL image
    to a tensor in ``[0, 1]`` and then normalizes with mean 0.5 / std 0.5,
    which maps the values to ``[-1, 1]``.
    """
    return transforms.Compose(
        [
            transforms.Resize(
                size,
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )


def _save_checkpoint(accelerator, model, save_dir):
    """Write the sub-network weights and the whole model into ``save_dir``."""
    os.makedirs(save_dir, exist_ok=True)
    # unwrap_model() returns the underlying FontDiffuserModel whether or not
    # Accelerate wrapped it (e.g. in DDP), so this works for single- and
    # multi-process runs alike; `model.module` only exists when wrapped.
    unwrapped_model = accelerator.unwrap_model(model)
    torch.save(unwrapped_model.unet.state_dict(), f"{save_dir}/unet.pth")
    torch.save(
        unwrapped_model.style_encoder.state_dict(),
        f"{save_dir}/style_encoder.pth",
    )
    torch.save(
        unwrapped_model.content_encoder.state_dict(),
        f"{save_dir}/content_encoder.pth",
    )
    # The prepared (possibly wrapped) object is pickled as before so that the
    # on-disk artifact keeps the shape existing consumers expect.
    torch.save(model, f"{save_dir}/total_model.pth")


def _timestamp():
    """Return the current local time formatted for the training log lines."""
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
